{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "934b74e1",
   "metadata": {},
   "source": [
    "# KDD'21 DLG4NLP Tutorial Demo: Semantic Parsing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4346fd68",
   "metadata": {},
   "source": [
    "### Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "767c3b2e",
   "metadata": {},
   "source": [
    "In this tutorial demo, we will use the Graph4NLP library to build a GNN-based semantic parsing model. The model consists of \n",
    "- graph construction module (e.g., node embedding based dynamic graph)\n",
    "- graph embedding module (e.g., Bi-Sep GAT)\n",
    "- predictoin module (e.g., RNN decoder with attention, copy and coverage mechanisms)\n",
    "\n",
    "We will use the built-in Graph2Seq model APIs to build the model, and evaluate it on the Jobs dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac004ea7",
   "metadata": {},
   "source": [
    "### Environment setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1de88301",
   "metadata": {},
   "source": [
    "Please follow the instructions [here](https://github.com/graph4ai/graph4nlp_demo#environment-setup) to set up the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "086ca306",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from graph4nlp.pytorch.datasets.jobs import JobsDataset\n",
    "from graph4nlp.pytorch.models.graph2seq import Graph2Seq\n",
    "from graph4nlp.pytorch.models.graph2seq_loss import Graph2SeqLoss\n",
    "from graph4nlp.pytorch.modules.graph_construction import *\n",
    "from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase\n",
    "from graph4nlp.pytorch.modules.utils.config_utils import update_values\n",
    "from graph4nlp.pytorch.modules.utils.copy_utils import prepare_ext_vocab\n",
    "from graph4nlp.pytorch.modules.config import get_basic_args"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf1b0d6b",
   "metadata": {},
   "source": [
    "### Build the model handler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c03b7543",
   "metadata": {},
   "source": [
    "Let's build a model handler which will do a bunch of things including setting up dataloader, model, optimizer, evaluation metrics, train/val/test loops, and so on.\n",
    "\n",
    "We will call the Graph2Seq model API which implements a GNN-based encoder and LSTM-based decoder.\n",
    "\n",
    "When setting up the dataloader, users will need to call the dataset API which will preprocess the data, e.g., calling the graph construction module, building the vocabulary, tensorizing the data. Users will need to specify the graph construction type when calling the dataset API.\n",
    "\n",
    "Users can build their customized dataset APIs by inheriting our low-level dataset APIs. We provide low-level dataset APIs to support various scenarios (e.g., `Text2Label`, `Sequence2Labeling`, `Text2Text`, `Text2Tree`, `DoubleText2Text`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1637d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Jobs:\n",
    "    def __init__(self, opt):\n",
    "        super(Jobs, self).__init__()\n",
    "        self.opt = opt\n",
    "        self.use_copy = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_copy\"]\n",
    "        self.use_coverage = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_coverage\"]\n",
    "        self._build_device(self.opt)\n",
    "        self._build_dataloader()\n",
    "        self._build_model()\n",
    "        self._build_loss_function()\n",
    "        self._build_optimizer()\n",
    "        self._build_evaluation()\n",
    "\n",
    "    def _build_device(self, opt):\n",
    "        seed = opt[\"seed\"]\n",
    "        np.random.seed(seed)\n",
    "        if opt[\"use_gpu\"] != 0 and torch.cuda.is_available():\n",
    "            print('[ Using CUDA ]')\n",
    "            torch.manual_seed(seed)\n",
    "            torch.cuda.manual_seed_all(seed)\n",
    "            from torch.backends import cudnn\n",
    "            cudnn.benchmark = True\n",
    "            device = torch.device('cuda' if opt[\"gpu\"] < 0 else 'cuda:%d' % opt[\"gpu\"])\n",
    "        else:\n",
    "            print('[ Using CPU ]')\n",
    "            device = torch.device('cpu')\n",
    "        self.device = device\n",
    "\n",
    "    def _build_dataloader(self):\n",
    "        if self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"dependency\":\n",
    "            topology_builder = DependencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"constituency\":\n",
    "            topology_builder = ConstituencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"node_emb\":\n",
    "            topology_builder = NodeEmbeddingBasedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"node_emb_refined\":\n",
    "            topology_builder = NodeEmbeddingBasedRefinedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            dynamic_init_graph_type = self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                \"dynamic_init_graph_type\"]\n",
    "            if dynamic_init_graph_type is None or dynamic_init_graph_type == 'line':\n",
    "                dynamic_init_topology_builder = None\n",
    "            elif dynamic_init_graph_type == 'dependency':\n",
    "                dynamic_init_topology_builder = DependencyBasedGraphConstruction\n",
    "            elif dynamic_init_graph_type == 'constituency':\n",
    "                dynamic_init_topology_builder = ConstituencyBasedGraphConstruction\n",
    "            else:\n",
    "                raise RuntimeError('Define your own dynamic_init_topology_builder')\n",
    "        else:\n",
    "            raise NotImplementedError(\"Define your topology builder.\")\n",
    "\n",
    "        \n",
    "        # Call the TREC dataset API\n",
    "        dataset = JobsDataset(root_dir=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"root_dir\"],\n",
    "                              pretrained_word_emb_name=self.opt[\"pretrained_word_emb_name\"],\n",
    "                              pretrained_word_emb_cache_dir=self.opt[\"pretrained_word_emb_cache_dir\"],\n",
    "                              merge_strategy=self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                                  \"merge_strategy\"],\n",
    "                              edge_strategy=self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                                  \"edge_strategy\"],\n",
    "                              seed=self.opt[\"seed\"],\n",
    "                              word_emb_size=self.opt[\"word_emb_size\"],\n",
    "                              share_vocab=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"share_vocab\"],\n",
    "                              graph_type=graph_type,\n",
    "                              topology_builder=topology_builder,\n",
    "                              topology_subdir=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"topology_subdir\"],\n",
    "                              thread_number=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"thread_number\"],\n",
    "                              dynamic_graph_type=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"graph_type\"],\n",
    "                              dynamic_init_topology_builder=dynamic_init_topology_builder,\n",
    "                              dynamic_init_topology_aux_args=None)\n",
    "\n",
    "        self.train_dataloader = DataLoader(dataset.train, batch_size=self.opt[\"batch_size\"], shuffle=True,\n",
    "                                           num_workers=1,\n",
    "                                           collate_fn=dataset.collate_fn)\n",
    "        self.test_dataloader = DataLoader(dataset.test, batch_size=self.opt[\"batch_size\"], shuffle=False, num_workers=1,\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.vocab = dataset.vocab_model\n",
    "\n",
    "    def _build_model(self):\n",
    "        # Call the Graph2Seq model API\n",
    "        self.model = Graph2Seq.from_args(self.opt, self.vocab).to(self.device)\n",
    "\n",
    "    def _build_loss_function(self):\n",
    "        # Call the Graph2Seq loss API\n",
    "        self.loss = Graph2SeqLoss(ignore_index=self.vocab.out_word_vocab.PAD,\n",
    "                                  use_coverage=self.use_coverage, coverage_weight=0.3)\n",
    "    def _build_optimizer(self):\n",
    "        parameters = [p for p in self.model.parameters() if p.requires_grad]\n",
    "        self.optimizer = optim.Adam(parameters, lr=self.opt[\"learning_rate\"])\n",
    "\n",
    "    def _build_evaluation(self):\n",
    "        self.metrics = [ExpressionAccuracy()]\n",
    "\n",
    "    def train(self):\n",
    "        max_score = -1\n",
    "        self._best_epoch = -1\n",
    "        for epoch in range(self.opt[\"epochs\"]):\n",
    "            self.model.train()\n",
    "            self.train_epoch(epoch, split=\"train\")\n",
    "            self._adjust_lr(epoch)\n",
    "            if epoch >= 0:\n",
    "                score = self.evaluate(split=\"test\")\n",
    "                if score >= max_score:\n",
    "                    print(\"Best model saved, epoch {}\".format(epoch))\n",
    "                    self.save_checkpoint(\"best.pth\")\n",
    "                    self._best_epoch = epoch\n",
    "                max_score = max(max_score, score)\n",
    "            if self._stop_condition(epoch):\n",
    "                break\n",
    "        return max_score\n",
    "\n",
    "    def _stop_condition(self, epoch, patience=20):\n",
    "        return epoch > patience + self._best_epoch\n",
    "\n",
    "    def _adjust_lr(self, epoch):\n",
    "        def set_lr(optimizer, decay_factor):\n",
    "            for group in optimizer.param_groups:\n",
    "                group['lr'] = group['lr'] * decay_factor\n",
    "\n",
    "        epoch_diff = epoch - self.opt[\"lr_start_decay_epoch\"]\n",
    "        if epoch_diff >= 0 and epoch_diff % self.opt[\"lr_decay_per_epoch\"] == 0:\n",
    "            if self.opt[\"learning_rate\"] > self.opt[\"min_lr\"]:\n",
    "                set_lr(self.optimizer, self.opt[\"lr_decay_rate\"])\n",
    "                self.opt[\"learning_rate\"] = self.opt[\"learning_rate\"] * self.opt[\"lr_decay_rate\"]\n",
    "                print(\"Learning rate adjusted: {:.5f}\".format(self.opt[\"learning_rate\"]))\n",
    "\n",
    "    def train_epoch(self, epoch, split=\"train\"):\n",
    "        assert split in [\"train\"]\n",
    "        print(\"Start training in split {}, Epoch: {}\".format(split, epoch))\n",
    "        loss_collect = []\n",
    "        dataloader = self.train_dataloader\n",
    "        step_all_train = len(dataloader)\n",
    "        for step, data in enumerate(dataloader):\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            tgt = tgt.to(self.device)\n",
    "            oov_dict = None\n",
    "            if self.use_copy:\n",
    "                oov_dict, tgt = prepare_ext_vocab(graph, self.vocab, gt_str=gt_str, device=self.device)\n",
    "\n",
    "            prob, enc_attn_weights, coverage_vectors = self.model(graph, tgt, oov_dict=oov_dict)\n",
    "            loss = self.loss(logits=prob, label=tgt, enc_attn_weights=enc_attn_weights,\n",
    "                             coverage_vectors=coverage_vectors)\n",
    "            loss_collect.append(loss.item())\n",
    "            if step % self.opt[\"loss_display_step\"] == 0 and step != 0:\n",
    "                print(\"Epoch {}: [{} / {}] loss: {:.3f}\".format(epoch, step, step_all_train,\n",
    "                                                                           np.mean(loss_collect)))\n",
    "                loss_collect = []\n",
    "            self.optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "\n",
    "    def evaluate(self, split=\"val\"):\n",
    "        self.model.eval()\n",
    "        pred_collect = []\n",
    "        gt_collect = []\n",
    "        assert split in [\"val\", \"test\"]\n",
    "        dataloader = self.val_dataloader if split == \"val\" else self.test_dataloader\n",
    "        for data in dataloader:\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            if self.use_copy:\n",
    "                oov_dict = prepare_ext_vocab(batch_graph=graph, vocab=self.vocab, device=self.device)\n",
    "                ref_dict = oov_dict\n",
    "            else:\n",
    "                oov_dict = None\n",
    "                ref_dict = self.vocab.out_word_vocab\n",
    "\n",
    "            prob, _, _ = self.model(graph, oov_dict=oov_dict)\n",
    "            pred = prob.argmax(dim=-1)\n",
    "\n",
    "            pred_str = wordid2str(pred.detach().cpu(), ref_dict)\n",
    "            pred_collect.extend(pred_str)\n",
    "            gt_collect.extend(gt_str)\n",
    "\n",
    "        score = self.metrics[0].calculate_scores(ground_truth=gt_collect, predict=pred_collect)\n",
    "        print(\"Evaluation accuracy in `{}` split: {:.3f}\".format(split, score))\n",
    "        return score\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def translate(self):\n",
    "        self.model.eval()\n",
    "\n",
    "        pred_collect = []\n",
    "        gt_collect = []\n",
    "        dataloader = self.test_dataloader\n",
    "        for data in dataloader:\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            if self.use_copy:\n",
    "                oov_dict = prepare_ext_vocab(batch_graph=graph, vocab=self.vocab, device=self.device)\n",
    "                ref_dict = oov_dict\n",
    "            else:\n",
    "                oov_dict = None\n",
    "                ref_dict = self.vocab.out_word_vocab\n",
    "\n",
    "            pred = self.model.translate(batch_graph=graph, oov_dict=oov_dict, beam_size=4, topk=1)\n",
    "\n",
    "            pred_ids = pred[:, 0, :]  # we just use the top-1\n",
    "\n",
    "            pred_str = wordid2str(pred_ids.detach().cpu(), ref_dict)\n",
    "\n",
    "            pred_collect.extend(pred_str)\n",
    "            gt_collect.extend(gt_str)\n",
    "\n",
    "        score = self.metrics[0].calculate_scores(ground_truth=gt_collect, predict=pred_collect)\n",
    "        return score\n",
    "\n",
    "    def load_checkpoint(self, checkpoint_name):\n",
    "        checkpoint_path = os.path.join(self.opt[\"checkpoint_save_path\"], checkpoint_name)\n",
    "        self.model.load_state_dict(torch.load(checkpoint_path))\n",
    "\n",
    "    def save_checkpoint(self, checkpoint_name):\n",
    "        checkpoint_path = os.path.join(self.opt[\"checkpoint_save_path\"], checkpoint_name)\n",
    "        if not os.path.exists(self.opt[\"checkpoint_save_path\"]):\n",
    "            os.makedirs(self.opt[\"checkpoint_save_path\"], exist_ok=True)\n",
    "        torch.save(self.model.state_dict(), checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dac6e553",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExpressionAccuracy(EvaluationMetricBase):\n",
    "    def __init__(self):\n",
    "        super(ExpressionAccuracy, self).__init__()\n",
    "\n",
    "    def calculate_scores(self, ground_truth, predict):\n",
    "        correct = 0\n",
    "        assert len(ground_truth) == len(predict)\n",
    "        for gt, pred in zip(ground_truth, predict):\n",
    "            if gt == pred:\n",
    "                correct += 1.\n",
    "        return correct / len(ground_truth)\n",
    "\n",
    "def wordid2str(word_ids, vocab):\n",
    "    ret = []\n",
    "    assert len(word_ids.shape) == 2, print(word_ids.shape)\n",
    "    for i in range(word_ids.shape[0]):\n",
    "        id_list = word_ids[i, :]\n",
    "        ret_inst = []\n",
    "        for j in range(id_list.shape[0]):\n",
    "            if id_list[j] == vocab.EOS:\n",
    "                break\n",
    "            token = vocab.getWord(id_list[j])\n",
    "            ret_inst.append(token)\n",
    "        ret.append(\" \".join(ret_inst))\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ebf198e",
   "metadata": {},
   "source": [
    "### Set up the config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a702636a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config setup\n",
    "config_file = '../config/jobs/gat_bi_sep_dynamic_node_emb.yaml'\n",
    "config = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader)\n",
    "\n",
    "opt = get_basic_args(graph_construction_name=config[\"graph_construction_name\"],\n",
    "                              graph_embedding_name=config[\"graph_embedding_name\"],\n",
    "                              decoder_name=config[\"decoder_name\"])\n",
    "update_values(to_args=opt, from_args_list=[config, config[\"other_args\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "985bc54a",
   "metadata": {},
   "source": [
    "### Run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cb15c2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ Using CPU ]\n",
      "Start training in split train, Epoch: 0\n",
      "Epoch 0: [10 / 21] loss: 3.938\n",
      "Epoch 0: [20 / 21] loss: 2.506\n",
      "Evaluation accuracy in `test` split: 0.000\n",
      "Best model saved, epoch 0\n",
      "Start training in split train, Epoch: 1\n",
      "Epoch 1: [10 / 21] loss: 1.845\n",
      "Epoch 1: [20 / 21] loss: 1.487\n",
      "Evaluation accuracy in `test` split: 0.000\n",
      "Best model saved, epoch 1\n",
      "Start training in split train, Epoch: 2\n",
      "Epoch 2: [10 / 21] loss: 1.198\n",
      "Epoch 2: [20 / 21] loss: 1.104\n",
      "Evaluation accuracy in `test` split: 0.100\n",
      "Best model saved, epoch 2\n",
      "Start training in split train, Epoch: 3\n",
      "Epoch 3: [10 / 21] loss: 1.015\n",
      "Epoch 3: [20 / 21] loss: 0.982\n",
      "Evaluation accuracy in `test` split: 0.071\n",
      "Start training in split train, Epoch: 4\n",
      "Epoch 4: [10 / 21] loss: 0.925\n",
      "Epoch 4: [20 / 21] loss: 0.873\n",
      "Evaluation accuracy in `test` split: 0.186\n",
      "Best model saved, epoch 4\n",
      "Start training in split train, Epoch: 5\n",
      "Epoch 5: [10 / 21] loss: 0.869\n",
      "Epoch 5: [20 / 21] loss: 0.826\n",
      "Evaluation accuracy in `test` split: 0.314\n",
      "Best model saved, epoch 5\n",
      "Start training in split train, Epoch: 6\n",
      "Epoch 6: [10 / 21] loss: 0.780\n",
      "Epoch 6: [20 / 21] loss: 0.815\n",
      "Evaluation accuracy in `test` split: 0.236\n",
      "Start training in split train, Epoch: 7\n",
      "Epoch 7: [10 / 21] loss: 0.758\n",
      "Epoch 7: [20 / 21] loss: 0.693\n",
      "Evaluation accuracy in `test` split: 0.386\n",
      "Best model saved, epoch 7\n",
      "Start training in split train, Epoch: 8\n",
      "Epoch 8: [10 / 21] loss: 0.702\n",
      "Epoch 8: [20 / 21] loss: 0.663\n",
      "Evaluation accuracy in `test` split: 0.471\n",
      "Best model saved, epoch 8\n",
      "Start training in split train, Epoch: 9\n",
      "Epoch 9: [10 / 21] loss: 0.645\n",
      "Epoch 9: [20 / 21] loss: 0.619\n",
      "Evaluation accuracy in `test` split: 0.507\n",
      "Best model saved, epoch 9\n",
      "Start training in split train, Epoch: 10\n",
      "Epoch 10: [10 / 21] loss: 0.635\n"
     ]
    }
   ],
   "source": [
    "# run the model\n",
    "runner = Jobs(opt)\n",
    "max_score = runner.train()\n",
    "print(\"Train finish, best val score: {:.3f}\".format(max_score))\n",
    "runner.load_checkpoint(\"best.pth\")\n",
    "# runner.evaluate(\"test\")\n",
    "test_score = runner.translate()\n",
    "print('Final test accuracy: {}'.format(test_score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
